Iblpupil jax backend#52
Conversation
There was a problem hiding this comment.
left a few small comments. one more substantial comment is that there are a few functions that duplicate code and can be condensed into a single function (smooth_min, inner_smooth_min_routine, pupil_smooth_final):
def pupil_smooth(y, smooth_params, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var, return_lls_only: False):
# Construct state transition matrix
diameter_s = smooth_params[0]
com_s = smooth_params[1]
A = jnp.array([
[diameter_s, 0, 0],
[0, com_s, 0],
[0, 0, com_s]
])
# cov_matrix
Q = jnp.array([
[diameters_var * (1 - (A[0, 0] ** 2)), 0, 0],
[0, x_var * (1 - A[1, 1] ** 2), 0],
[0, 0, y_var * (1 - (A[2, 2] ** 2))]
])
if return_lls_only:
# Run filtering with the current smooth_param
_, _, nll = jax_forward_pass(y, m0, S0, A, Q, C, R, ensemble_var)
else:
# Run filtering and smoothing with the current smooth_param
mf, Vf, nll, nll_array = jax_forward_pass_nlls(y, m0, S0, A, Q, C, R, ensemble_vars)
ms, Vs = jax_backward_pass(mf, Vf, A, Q)
return ms, Vs, nll_array
This means A, Q are only initialized once, and makes clearer the two different modes of filtering (for nll computation) and smoothing
|
Another suggestion: in a future PR it would be good to either unify |
No description provided.